Warn if using tied target module with tie_word_embeddings #2025
Conversation
Thanks for creating this PR. I think we need to rethink the approach here, as the current one will not work in all situations:

- `get_peft_model` is a very generic function and is also used for prompt tuning methods, for instance. Therefore, we cannot assume that `peft_config.target_modules` exists.
- Not all methods allow merging the weights, so we should not warn in those cases (false warnings should be avoided as much as possible).
- Even if `peft_config.target_modules` does exist, it could be a string, so looping over it will not always be correct (see the sketch after this list).
- As we already observed, it will not work for custom models with tied weights, but let's consider this out of scope for now.
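To make the string pitfall concrete, here is a tiny, hypothetical illustration (not PEFT code) of why looping over `target_modules` directly is wrong when it is a plain string:

```python
# `target_modules` may be a single module name given as a string, not a list.
target_modules = "lm_head"

for name in target_modules:
    # Iterates over the characters "l", "m", "_", "h", ... instead of module names.
    print(name)
```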
So how can we correctly identify when a warning is needed? My proposal is that this needs to be solved on a different level: the check for a tied target layer needs to live on the corresponding method's model level (e.g. `LoraModel`), as only there can we really know which layers are targeted. Thankfully, the models that support merging all inherit from `BaseTuner`. There, we have the `inject_adapter` method. If you look at this line, you can see that all modules that are actually targeted are stored in `self.targeted_module_names`. Therefore, after exiting the loop, we can add a new method that takes this list and checks if any of the keys are tied weights, using the logic you proposed.
This new check should be implemented as a new method on the `BaseTuner` class, so that subclasses such as `LoraModel` may choose to override the method if there ever is a need.
Additionally, I wonder if there should be a warning when the user attempts to merge. One could argue that this is too late, but even at this point, there are workarounds: if the user clones the tied weights, they can merge without affecting the other weight (at the cost of extra memory).
This additional warning could be added to the `_check_merge_allowed` method, and it could re-use the same method as mentioned above to perform the check. However, the warning message should be a bit different.
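For illustration, a minimal sketch of this two-part proposal; it is a sketch only and not the actual PEFT API. The helper name, the `EMBEDDING_LAYER_NAMES` list, and the way the model config is passed in are all assumptions.

```python
import warnings

# Assumed names of layers that are typically tied when `tie_word_embeddings=True`.
EMBEDDING_LAYER_NAMES = ("embed_tokens", "lm_head")


def warn_if_tied_modules_targeted(targeted_module_names, model_config, when_merging=False):
    """Sketch of a check to run after `inject_adapter` has filled `self.targeted_module_names`."""
    if not model_config.get("tie_word_embeddings", False):
        return
    tied_targets = [
        name for name in targeted_module_names if name.split(".")[-1] in EMBEDDING_LAYER_NAMES
    ]
    if not tied_targets:
        return
    if when_merging:
        # Reused from a `_check_merge_allowed`-style hook, with an adapted message.
        warnings.warn(
            f"Merging adapters on the tied modules {tied_targets} will also change the tied counterpart; "
            "clone the weight first to avoid this (at the cost of extra memory)."
        )
    else:
        warnings.warn(
            f"The adapter targets {tied_targets}, which are tied because `tie_word_embeddings=True`; "
            "this can lead to complications, e.g. when merging the adapter."
        )
```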
I know this is all a bit more complicated than initially thought and not necessarily what you "signed up for". So let me know if you still want to work on this or not, in which case I'll put this on my backlog.
Not at all, thanks, sounds really good, I'll have a go!
Thanks a lot.
@BenjaminBossan I made a version addressing your suggestions. Also, I refactored getting the model config in the code base.
I feel like the new message can be the same. Let me know. (I can't run the whole test suite as I do not have a cuda-compatible gpu.)
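For context, a rough sketch of what a shared config accessor could look like, using names that appear later in this thread (`get_model_config`, `DUMMY_MODEL_CONFIG`); the fallback value is an assumption, not necessarily what the PR implements:

```python
from torch import nn

# Assumed fallback for models that have no `config` attribute.
DUMMY_MODEL_CONFIG = {"model_type": "custom"}


def get_model_config(model: nn.Module) -> dict:
    """Return the model config as a dict, falling back to a dummy config if the model has none."""
    config = getattr(model, "config", DUMMY_MODEL_CONFIG)
    if hasattr(config, "to_dict"):  # e.g. a transformers `PretrainedConfig`
        config = config.to_dict()
    return config
```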
Thanks for updating the PR.
I feel like the new message can be the same. Let me know.
I think the error message is good as is for when the model is being initialized. When merging, I think we could show a different warning, where we mention that if the weight is cloned beforehand, merging should work, at the cost of higher memory usage.
To implement this, I would change the `_warn_if_tied_embeddings_in_target_modules` method from warning to just performing the check and returning a bool (renaming the method accordingly). Then during injection, if the check returns `True`, the current warning is given, and during merging, if the check returns `True`, the adapted warning is given. WDYT?
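A compact sketch of that split, with hypothetical names; the check reuses the same logic as the earlier sketch, while each call site emits its own message:

```python
import warnings

EMBEDDING_LAYER_NAMES = ("embed_tokens", "lm_head")  # assumed tied layer names


def has_tied_target_modules(targeted_module_names, model_config) -> bool:
    """Pure check, no warning: are any tied embedding layers among the targeted modules?"""
    return model_config.get("tie_word_embeddings", False) and any(
        name.split(".")[-1] in EMBEDDING_LAYER_NAMES for name in targeted_module_names
    )


targeted, config = ["model.lm_head"], {"tie_word_embeddings": True}  # toy inputs

# Call site 1: during injection.
if has_tied_target_modules(targeted, config):
    warnings.warn("Tied embedding layers are targeted; this can lead to complications when merging.")

# Call site 2: during merging (e.g. from a `_check_merge_allowed`-style hook), with an adapted message.
if has_tied_target_modules(targeted, config):
    warnings.warn(
        "Tied embedding layers are targeted; clone the tied weight before merging "
        "so the other tied module stays unchanged (at the cost of extra memory)."
    )
```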
(I can't run the whole test suite as I do not have a cuda-compatible gpu.)
This is fine.
Thanks for the changes. I have a few small suggestions for improvements; please check them out.
It would also be great to add unit tests for this, but probably this will be a bit more complicated. I leave it up to you if you want to give this a try, otherwise I'll work on it in a subsequent PR.
Sure, very happy to write tests! I'll put them in tests/test_tuners_utils.py. Just one question: to mock models with tied embeddings, should I use the test model "HuggingFaceH4/tiny-random-LlamaForCausalLM" but loaded with:
model = AutoModelForCausalLM.from_pretrained(model_id, tie_word_embeddings=True)
@BenjaminBossan I added the test here.
Thanks so much for making the updates, using `DUMMY_MODEL_CONFIG` consistently and extending the tests. This looks quite good already, but I have some suggestions for improvements, please check.
I just did this as it's a bit unclear if in this case the model_config needs to default to None or if it can be the DUMMY one, let me know!
The change you made looks good as is.
Just one question: to mock models with tied embeddings, should I use the test model "HuggingFaceH4/tiny-random-LlamaForCausalLM" but loaded with:
I didn't know that this was an option. Yes, looks like the right choice.
src/peft/tuners/tuners_utils.py (outdated)
    warnings.warn(
        f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
        "This can lead to complications when merging the adapter. "
        "You can opt to merge the adapter after cloning the weights (to untie the embeddings), "
Honestly, I didn't know about the option to pass `tie_word_embeddings=False`. Is there even a need to clone the weights in that case?
Looks like it works; I added the code to create the untied model to the warning message.
tests/test_tuners_utils.py (outdated)
        config = BaseTuner.get_model_config(ModelWithNoConfig())
        assert config == DUMMY_MODEL_CONFIG

    def test_warn_for_tied_embeddings_inject_and_merge(self):
Thanks a lot for adding these tests. They are already looking quite good. I think, however, that this last test can be simplified a bit.
As you correctly observed, there are 6 scenarios to test:

- Warning for `get_peft_model` and warning for merging.
- Valid warning vs. no tied embeddings vs. tied embeddings but not targeted.

Instead of cramming those into a single test, let's make this 6 separate tests. It should also be fine to make it 3 tests, where `get_peft_model` and merging are checked together. Hopefully, this should make the `assert_warning_triggered` function unnecessary.
You probably also had a bit of an issue that unrelated warnings could be recorded. Maybe this can be made simpler by using the `recwarn` fixture. Then you can just check that any warning has been recorded with the corresponding message, something like:
assert any(str(warning.message).startswith(msg) for warning in recwarn.list)
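For illustration, a sketch of how one of these tests could look, assuming the tiny Llama checkpoint mentioned above, pytest's built-in `recwarn` fixture, and a message prefix taken from the warning text quoted earlier in this thread; it is a sketch, not the test added in the PR:

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model


def test_warns_when_tied_embedding_is_targeted(recwarn):
    # Tiny test model mentioned earlier in the thread, loaded with tied embeddings.
    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/tiny-random-LlamaForCausalLM", tie_word_embeddings=True
    )
    config = LoraConfig(target_modules=["lm_head"])
    get_peft_model(model, config)

    # Prefix of the warning shown earlier in this thread; the exact wording may differ.
    msg = "Model with `tie_word_embeddings=True`"
    assert any(str(warning.message).startswith(msg) for warning in recwarn.list)
```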
tests/test_tuners_utils.py (outdated)
    pass


class TestBaseTunerMethods(unittest.TestCase):
Let's split this test class into two: one for `get_model_config` and one for the tied embeddings.
@BenjaminBossan I think I've addressed the comments 👍
Thanks for the latest updates. I only have one more question, namely when it comes to how to untie the weights. In the script you provide, you clone the weights, but is that even necessary if the model is loaded with `tie_word_embeddings=False`? Check this:
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", tie_word_embeddings=False)
>>> model.model.embed_tokens.weight.data_ptr()
126062054867008
>>> model.lm_head.weight.data_ptr() # <= different data ptr
126051845931072
>>> model.model.embed_tokens.weight.sum()
tensor(952564.6250, grad_fn=<SumBackward0>)
>>> model.lm_head.weight.sum()
tensor(255.3427, grad_fn=<SumBackward0>)
>>> from peft import LoraConfig, get_peft_model
>>> config = LoraConfig(init_lora_weights=False, target_modules=["embed_tokens"])
>>> model = get_peft_model(model, config)
>>> unloaded = model.merge_and_unload()
>>> unloaded.model.embed_tokens.weight.sum() # <= embed weights changed
tensor(985655.8125)
>>> unloaded.lm_head.weight.sum() # <= lm head stayed the same
tensor(255.3427)
Yes, I agree with your script, but the user wants to fix the randomly initialized lm_head so that it matches the previously tied embeddings. This cloning also seems to allow saving it correctly. If you do not clone (besides actually re-tying the embeddings), then when you load the saved untied model the last assertion below will fail; otherwise, if you clone, it will pass:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)
# Set the randomly initialized lm_head to the previously tied embeddings
model.lm_head.weight.data = model.model.embed_tokens.weight.data  # use .clone() here to make the final assertion pass after reloading
assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data)
# Save the untied model
untied_model_dir = "tmp_model"
model.save_pretrained(untied_model_dir)
model.config.save_pretrained(untied_model_dir)
# Now use the original model but in untied format
model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
assert model.model.embed_tokens.weight.data.data_ptr() != model.lm_head.weight.data.data_ptr()
assert torch.equal(model.lm_head.weight.data, model.model.embed_tokens.weight.data)
Oh wow, I did not know that the LM head will be randomly initialized, that's quite surprising IMO. I would have expected to get the same parameter values, just not tied. Thanks for making me aware of that.
Not sure how to reproduce the error in the GitHub Actions:
ruff check src tests examples docs scripts docker
All checks passed!
ruff format --check src tests examples docs scripts docker
189 files already formatted
doc-builder style src/peft tests docs/source --max_len 119 --check_only
Traceback (most recent call last):
File "/opt/hostedtoolcache/Python/3.8.18/x64/bin/doc-builder", line 8, in <module>
sys.exit(main())
File "/opt/hostedtoolcache/Python/3.8.[18](https://github.com/huggingface/peft/actions/runs/10578337704/job/29356715757?pr=2025#step:5:19)/x64/lib/python3.8/site-packages/doc_builder/commands/doc_builder_cli.py", line 47, in main
args.func(args)
File "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/doc_builder/commands/style.py", line 28, in style_command
raise ValueError(f"{len(changed)} files should be restyled!")
ValueError: 1 files should be restyled!
make: *** [Makefile:11: quality] Error 1
Error: Process completed with exit code 2.
@ltoniazzi could you please run the doc-builder style command locally to fix the formatting?
I ran it; this is the resulting diff:
@@ -530,8 +530,8 @@ model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
@staticmethod
def get_model_config(model: nn.Module) -> dict:
"""
- This method gets the config from a model in dictionary form.
- If model has not attribute config, then this method returns a default config.
+ This method gets the config from a model in dictionary form. If model has not attribute config, then this
+ method returns a default config.
Done!
Thanks, very nicely done PR. I just have two tiny comments for cosmetic reasons, otherwise this can be merged.
tests/test_tuners_utils.py (outdated)
        )
        return model

    def _is_warn_triggered(self, rrecwarn, endswith):
Did you call it `rrecwarn` to avoid naming conflicts? If yes, how about just passing the `recwarn.list`, which is all we need, and call it `warning_list` or so.
src/peft/tuners/tuners_utils.py (outdated)
    # Now use the original model but in untied format
    model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
    ```
    """
I see why you left-aligned the code snippet so that it is nicely printed. But this is really an eyesore to read in code. Here is a trick that lets us use the correct indentation but still get a nice warning message, by using `textwrap.dedent`:
example_code = textwrap.dedent(
    """
    ```python
    from transformers import AutoModelForCausalLM
    # Load original tied model
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", tie_word_embeddings=False)
    # Set the randomly initialized lm_head to the previously tied embeddings
    model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()
    # Save the untied model
    untied_model_dir = "dir/for/untied/model"
    model.save_pretrained(untied_model_dir)
    model.config.save_pretrained(untied_model_dir)
    # Now use the original model but in untied format
    model = AutoModelForCausalLM.from_pretrained(untied_model_dir)
    ```
    """
)
warnings.warn(
    f"Model with `tie_word_embeddings=True` and the {tied_target_modules=} are part of the adapter. "
    "This can lead to complications. "
    "You can opt to merge the adapter after cloning the weights (to untie the embeddings). "
    "You can untie the embeddings by loading the model with `tie_word_embeddings=False`. For example:"
    + example_code
)
The `textwrap` module is from the standard library and needs to be imported.
Addressed both, thanks!
Thanks so much, great work, hopefully this will help users in the future to avoid this potential pitfall.
@BenjaminBossan Thanks so much for your help! ❤️ Btw, a test on main failed, do you think it's related to this PR?
Don't worry, this is a known issue with X-LoRA that came about with a recent change in transformers.
Context
Solving issue #2018. Warn if the user includes a tied `target_module` when the embeddings are tied (`tie_word_embeddings=True`), because this could lead to errors, for example when merging the adapter.

Todo
- Try whether loading with `tie_word_embeddings=False` is an actual option: load Gemma2 with differently finetuned lm_weights and check that the lm_head is not replaced with the embedding (even if cloned). If it works, try to merge an adapter into lm_weight and then load it to check whether embed and lm_head are kept separate (the main concern is that the loading model's architecture might ignore any lm_head weight present in safetensors, as happens in llama.cpp, for example). A sketch of such a check follows this list.
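A hedged sketch of the in-memory part of that check, reusing the model id and access paths from the scripts above (a full check would also save and reload the merged model); the adapter config is a placeholder:

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# Load the model untied, so embed_tokens and lm_head have separate storage.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", tie_word_embeddings=False)

# Attach a LoRA adapter to lm_head and merge it into the base weights.
config = LoraConfig(init_lora_weights=False, target_modules=["lm_head"])
merged = get_peft_model(model, config).merge_and_unload()

# If everything stays untied, the two modules still point at different storage.
assert merged.model.embed_tokens.weight.data_ptr() != merged.lm_head.weight.data_ptr()
```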